Skip to content

[JAX] Support for cuDNN-backed flex attention#2985

Open
vcherepanov-nv wants to merge 14 commits into
NVIDIA:mainfrom
vcherepanov-nv:cudnn-score-mod-jax
Open

[JAX] Support for cuDNN-backed flex attention#2985
vcherepanov-nv wants to merge 14 commits into
NVIDIA:mainfrom
vcherepanov-nv:cudnn-score-mod-jax

Conversation

@vcherepanov-nv
Copy link
Copy Markdown
Collaborator

@vcherepanov-nv vcherepanov-nv commented May 13, 2026

Description

Adds experimental JAX fused-attention score_mod support through cuDNN frontend SDPA graphs.

This introduces a score_mod(graph, score, tensors) callback path for fused_attn, plus optional score_mod_bprop(graph, dscore, tensors) support for backward. The Python side builds and serializes cuDNN frontend forward/backward graphs, caches graph metadata with stable callback keys, supports auxiliary tensor operands, and supports Python/NumPy scalar operands as cuDNN pass-by-value tensors. The C++ JAX extension deserializes and caches the graphs per device, then executes them through new forward/backward FFI handlers.

The Flax API now plumbs score_mod through DotProductAttention, MultiHeadAttention, and TransformerLayer. Packed QKV/KV layouts are unpacked to the separate BSHD layout when score modification is requested.

Users are responsible for supplying a mathematically correct score_mod_bprop for the corresponding score_mod; Transformer Engine wires the callback into the cuDNN graph but does not validate gradient semantics.

Current score_mod limitations:

  • Requires fused attention to be enabled.
  • Supports separate rank-4 BSHD_BSHD_BSHD Q/K/V tensors only.
  • Supports FP16/BF16 Q/K/V tensors.
  • Mutually exclusive with attention bias, masks, sequence descriptors, dropout, sliding-window attention, packed/ragged metadata, context parallelism, and non-vanilla softmax/softmax offset.
  • Requires matching cuDNN frontend Python package and C++ headers.

Fixes # (issue)
#2492

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • A new score_mod code path for the JAX FusedAttention backend
  • cuDNN frontend graph serialization and JAX FFI execution for score_mod forward/backward
  • Flax plumbing for DotProductAttention, MultiHeadAttention, and TransformerLayer
  • Tests

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
@greptile-apps
Copy link
Copy Markdown
Contributor

greptile-apps Bot commented May 13, 2026

Greptile Summary

This PR adds experimental cuDNN-frontend-backed flex attention (score_mod) to the JAX backend, including graph serialization/deserialization, a Python-level graph cache, new C++ FFI handlers for forward and backward, and Flax plumbing through DotProductAttention, MultiHeadAttention, and TransformerLayer.

  • Core path: fused_attn short-circuits to a new _fused_attn_score_mod custom_vjp primitive when score_mod is provided; the Python side builds and serializes cuDNN frontend graphs at trace time (cached by shape/dtype/config key), then passes the serialized bytes + UID maps as static FFI attributes to C++ handlers that deserialize and execute them.
  • C++ side: Two new FFI handlers deserialize graphs on demand into a process-lifetime unordered_map guarded by a mutex, with a thread-local cuDNN handle cache; the current double-checked locking leaves a window for redundant concurrent deserialization.
  • Flax plumbing: Packed and KV-packed layouts are transparently converted to separate BSHD tensors before the score_mod path; score_mod_tensors / score_mod_bprop_tensors are forwarded as call-time arguments to keep tensor operands in the JAX computation graph.

Confidence Score: 5/5

Safe to merge as an experimental feature; all flagged items are non-blocking quality improvements with no correctness impact.

The core forward and backward graph building, caching, FFI dispatch, and Flax plumbing are all structurally correct. Cache key stability, UID ordering, and pytree gradient structure are handled properly. The findings are race conditions that produce at worst redundant work (not wrong results) and a shutdown-order concern for the thread-local cuDNN handle that matches patterns already present elsewhere in the codebase.

transformer_engine/jax/csrc/extensions/attention.cpp (double-checked locking in GetScoreModGraph, thread-local handle destructor ordering) and transformer_engine/jax/cpp_extensions/flex_attention.py (Python-level cache lock).

Important Files Changed

Filename Overview
transformer_engine/jax/cpp_extensions/flex_attention.py New 967-line file implementing cuDNN frontend score_mod graph building, caching, and FFI dispatch; implements a stable cache-key scheme and separates tensor vs. scalar operands cleanly.
transformer_engine/jax/csrc/extensions/attention.cpp Adds 251 lines for C++ cuDNN graph deserialization, thread-local handle cache, and two new FFI handlers (forward/backward); double-checked locking leaves redundant deserializations possible under thread contention.
transformer_engine/jax/attention.py Adds custom_vjp wrapper for score_mod path with correct residual propagation, early-return before the deprecated sequence_descriptor path, and proper validation delegation.
transformer_engine/jax/flax/transformer.py Plumbs score_mod/score_mod_bprop through DotProductAttention, MultiHeadAttention, and TransformerLayer; handles packed/kvpacked layout unpacking correctly before the score_mod path.
tests/jax/test_fused_attn_score_mod.py New 671-line test suite covering causal masking, post-scale bias, softcap (forward/backward), and Flax layer integration, with reference implementations for correctness comparison.

Sequence Diagram

sequenceDiagram
    participant User
    participant fused_attn
    participant ScoreMod as "_fused_attn_score_mod"
    participant FlexPy as "flex_attention.py"
    participant FFI as "FFI/XLA"
    participant Cpp as "C++ Handler"
    participant Cache as "cuDNN Graph Cache"

    User->>fused_attn: "call with score_mod callback"
    fused_attn->>fused_attn: "validate_fused_attn_score_mod()"
    fused_attn->>FlexPy: "make_fused_attn_score_mod_config()"
    fused_attn->>ScoreMod: "custom_vjp forward"

    Note over ScoreMod,FlexPy: JAX Tracing Phase
    ScoreMod->>FlexPy: "fused_attn_score_mod_fwd()"
    FlexPy->>FlexPy: "check _score_mod_graph_cache"
    alt cache miss
        FlexPy->>FlexPy: "_build_score_mod_fwd_graph()"
        FlexPy->>FlexPy: "store in _score_mod_graph_cache"
    end
    FlexPy->>FFI: "ffi.ffi_call(serialized_graph, uids)"

    Note over FFI,Cache: XLA Execution Phase
    FFI->>Cpp: "FusedAttnScoreModForwardFFI(stream, q, k, v)"
    Cpp->>Cache: "GetScoreModGraph(stream, attrs)"
    alt C++ cache miss
        Cache->>Cache: "graph->deserialize(handle, data)"
        Cache->>Cache: "store shared_ptr in map"
    end
    Cpp->>Cpp: "graph->execute(handle, variant_pack)"
    Cpp-->>FFI: "output, stats, workspace"

    Note over ScoreMod,FlexPy: Backward pass
    ScoreMod->>FlexPy: "fused_attn_score_mod_bwd(qkv, o, dO, stats)"
    FlexPy->>FFI: "ffi.ffi_call(serialized_bwd_graph)"
    FFI->>Cpp: "FusedAttnScoreModBackwardFFI(...)"
    Cpp-->>FFI: "dq, dk, dv"
Loading

Reviews (9): Last reviewed commit: "Skip softcap score-mod test before SM90" | Re-trigger Greptile

Comment thread transformer_engine/jax/csrc/extensions/attention.cpp Outdated
Comment thread transformer_engine/jax/cpp_extensions/attention.py Outdated
Comment thread transformer_engine/jax/csrc/extensions/attention.cpp Outdated
Comment thread transformer_engine/jax/cpp_extensions/attention.py Outdated
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
Comment thread tests/jax/test_fused_attn.py Outdated
vcherepanov-nv and others added 2 commits May 15, 2026 03:35
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
Comment thread transformer_engine/jax/cpp_extensions/attention.py Outdated
Comment thread transformer_engine/jax/cpp_extensions/attention.py Outdated
Comment thread transformer_engine/jax/csrc/extensions/attention.cpp Outdated
Comment thread transformer_engine/jax/csrc/extensions/attention.cpp Outdated
Comment thread transformer_engine/jax/attention.py Outdated
vcherepanov-nv and others added 2 commits May 18, 2026 23:54
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
Comment thread transformer_engine/jax/csrc/extensions/attention.cpp Outdated
Comment thread transformer_engine/jax/csrc/extensions/attention.cpp Outdated
vcherepanov-nv and others added 2 commits May 19, 2026 00:32
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
Comment thread transformer_engine/jax/attention.py
Comment thread transformer_engine/jax/attention.py
Comment thread transformer_engine/jax/attention.py Outdated
Comment thread transformer_engine/jax/attention.py Outdated
Comment thread transformer_engine/jax/attention.py
Comment thread tests/jax/test_distributed_fused_attn.py
Comment thread tests/jax/test_distributed_fused_attn.py
Comment thread tests/jax/test_distributed_fused_attn.py Outdated
Comment thread tests/jax/test_fused_attn.py Outdated
Comment thread tests/jax/test_fused_attn.py Outdated
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
Comment thread transformer_engine/jax/attention.py
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
@github-actions github-actions Bot added the community-contribution PRs from external contributor outside the core maintainers, representing community-driven work. label May 21, 2026
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

2.16.0 community-contribution PRs from external contributor outside the core maintainers, representing community-driven work.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants